import time
import warnings
from typing import Any, Callable, Dict, List, Optional, Union
from copy import deepcopy as dco
import random as rdm

import gymnasium as gym
import numpy as np
import torch
import torch.nn.functional as F
from tianshou.data import (
    Batch,
    ReplayBuffer,
    ReplayBufferManager,
    VectorReplayBuffer,
    to_numpy,
)
from tianshou.env import BaseVectorEnv, DummyVectorEnv

from fsrl.policy import BasePolicy
from PIL import Image
import matplotlib.pyplot as plt


class FastCollector(object):
    """Collector enables the policy to interact with different types of envs with \
    exact number of episodes.

    This collector is a simplified version of Tianshou's `collector
    <https://tianshou.readthedocs.io/en/master/api/tianshou.data.html#collector>`_, so
    it is safe to check their documentation for details. The main change is the \
        support to extract the cost signals from the interaction data.

    :param policy: an instance of the :class:`~fsrl.policy.BasePolicy` class.
    :param env: a ``gym.Env`` environment or an instance of the
        :class:`~tianshou.env.BaseVectorEnv` class.
    :param buffer: an instance of the :class:`~tianshou.data.ReplayBuffer` class. If set
        to None, it will not store the data. Default to None.
    :param function preprocess_fn: a function called before the data has been added to
        the buffer. Default to None.
    :param bool exploration_noise: determine whether the action needs to be modified with
        corresponding policy's exploration noise. If so, "policy. exploration_noise(act,
        batch)" will be called automatically to add the exploration noise into action.
        Default to False.

    .. note::

        Please make sure the given environment has a time limit (i.e., every episode
        eventually ends), because only the `n_episode` collect option is supported.
    """

    def __init__(
        self,
        policy: BasePolicy,
        env: Union[gym.Env, BaseVectorEnv],
        buffer: Optional[ReplayBuffer] = None,
        preprocess_fn: Optional[Callable[..., Batch]] = None,
        exploration_noise: bool = False,
        shield=None,
        use_shield=False,
        use_shield_distill=False,
        use_fake_data=False,
        fake_next_obs=True,
        fake_done=False,
        fake_rew=-1.0,
        fake_cost=1.0,
        fake_weight=1.0,
        shield_prop=1.0,
        cost_shield_thres=0.01,
        actor_shield_thres=0.02,
        action_dim=4,
        prepare_steps=2000,
    ) -> None:
        """Set up the environment, buffer, policy and shield configuration.

        :param policy: policy used to select actions during collection.
        :param env: a single ``gym.Env`` (wrapped into a ``DummyVectorEnv``) or an
            already vectorized environment.
        :param buffer: replay buffer to store transitions; validated/created by
            :meth:`_assign_buffer`.
        :param preprocess_fn: optional hook applied to reset outputs.
        :param exploration_noise: whether ``policy.exploration_noise`` is applied to
            the policy's actions in :meth:`collect`.
        :param shield: object queried in :meth:`collect` through ``get_cost(obs,
            act)``, ``get_action(obs)`` and ``total_call``.
        :param use_shield: enable the shield; only effective if ``shield`` is given.
        :param use_shield_distill: additionally consult ``policy.cost_shield_distill``
            and ``policy.actor_shield_distill``; only effective with the shield on.
        :param use_fake_data: store an extra "fake" transition for each action that
            the shield flags; only effective with the shield on.
        :param fake_next_obs: if true, the fake transition keeps the current
            observation as its next observation.
        :param fake_done: if true, fake transitions are marked terminated, truncated
            and done.
        :param fake_rew: reward written for shield-flagged actions.
        :param fake_cost: cost written for shield-flagged actions.
        :param fake_weight: value stored in ``info["weight"]`` of fake transitions.
        :param shield_prop: the shield is skipped for a step when a uniform random
            draw exceeds this value.
        :param cost_shield_thres: variance threshold below which the distilled cost
            model's prediction is used instead of ``shield.get_cost``.
        :param actor_shield_thres: variance threshold below which the distilled
            actor's prediction is used instead of ``shield.get_action``.
        :param action_dim: number of classes used to one-hot encode actions for the
            distilled cost model.
        :param prepare_steps: the distilled models are only trusted once
            ``collect_step`` has reached this value.
        """
        super().__init__()

        # A bare gym environment has no length; vectorize it so that the rest of
        # the collector can always index per-environment data.
        needs_wrapping = isinstance(env, gym.Env) and not hasattr(env, "__len__")
        if needs_wrapping:
            warnings.warn("Single environment detected, wrap to DummyVectorEnv.")
        self.env = DummyVectorEnv([lambda: env]) if needs_wrapping else env

        self.env_num = len(self.env)
        self.exploration_noise = exploration_noise
        self._assign_buffer(buffer)
        self.policy = policy
        self.preprocess_fn = preprocess_fn
        self._action_space = self.env.action_space
        # Creates self.data and the statistic counters inside __init__.
        self.reset(False)

        # The fake-data and distillation features are sub-options of the shield:
        # they stay disabled unless a shield object is present and switched on.
        shield_active = bool(shield is not None and use_shield)
        self.shield = shield
        self.use_shield = shield_active
        self.use_fake_data = bool(shield_active and use_fake_data)
        self.use_llm_shield = bool(shield_active and use_shield_distill)

        self.fake_next_obs = fake_next_obs
        self.fake_done = fake_done
        self.fake_rew = fake_rew
        self.fake_cost = fake_cost
        self.fake_weight = fake_weight
        self.shield_prop = shield_prop

        self.cost_shield_thres = cost_shield_thres
        self.actor_shield_thres = actor_shield_thres

        self.action_dim = action_dim
        self.prepare_steps = prepare_steps

    def _assign_buffer(self, buffer: Optional[ReplayBuffer]) -> None:
        """Validate ``buffer`` against the number of environments and store it.

        :param buffer: the buffer to use. ``None`` creates a
            ``VectorReplayBuffer(self.env_num, self.env_num)``.

        :raises AssertionError: if a buffer manager has fewer sub-buffers than
            environments, or a plain buffer has a non-positive ``maxsize``.
        :raises TypeError: if a non-vectorized buffer is given while more than one
            environment is collected from.
        """
        if buffer is None:
            buffer = VectorReplayBuffer(self.env_num, self.env_num)
        elif isinstance(buffer, ReplayBufferManager):
            assert buffer.buffer_num >= self.env_num
        else:  # ReplayBuffer or PrioritizedReplayBuffer
            assert buffer.maxsize > 0
            if self.env_num > 1:
                # An exact type comparison is intended here: subclasses of
                # ReplayBuffer must fall through to the other branch so that the
                # error message names the matching vectorized class.
                if type(buffer) is ReplayBuffer:
                    buffer_type = "ReplayBuffer"
                    vector_type = "VectorReplayBuffer"
                else:
                    buffer_type = "PrioritizedReplayBuffer"
                    vector_type = "PrioritizedVectorReplayBuffer"
                raise TypeError(
                    f"Cannot use {buffer_type}(size={buffer.maxsize}, ...) to collect "
                    f"{self.env_num} envs,\n\tplease use {vector_type}(total_size="
                    f"{buffer.maxsize}, buffer_num={self.env_num}, ...) instead."
                )
        self.buffer = buffer

    def reset(
        self,
        reset_buffer: bool = True,
        gym_reset_kwargs: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Start from a clean slate: data, environments, statistics and optionally
        the replay buffer.

        :param bool reset_buffer: if true, also reset the attached replay buffer.
        :param gym_reset_kwargs: extra keyword arguments forwarded to the
            environment's reset function. Defaults to None (no extra arguments).
        """
        # Every field starts as an empty mapping so that self.data can be sliced
        # before any real value has been written into it.
        field_names = (
            "obs",
            "act",
            "rew",
            "cost",
            "terminated",
            "truncated",
            "done",
            "obs_next",
            "info",
            "policy",
        )
        self.data = Batch(**{name: {} for name in field_names})

        self.reset_env(gym_reset_kwargs)
        if reset_buffer:
            self.reset_buffer()
        self.reset_stat()

    def reset_stat(self) -> None:
        """Reset the statistic variables."""
        self.collect_step, self.collect_episode, self.collect_time = 0, 0, 0.0

    def reset_buffer(self, keep_statistics: bool = False) -> None:
        """Reset the data buffer."""
        self.buffer.reset(keep_statistics=keep_statistics)

    def reset_env(self, gym_reset_kwargs: Optional[Dict[str, Any]] = None) -> None:
        """Reset all of the environments."""
        gym_reset_kwargs = gym_reset_kwargs if gym_reset_kwargs else {}
        rval = self.env.reset(**gym_reset_kwargs)
        returns_info = isinstance(rval, (tuple, list)) and len(rval) == 2 and (
            isinstance(rval[1], dict) or isinstance(rval[1][0], dict)
        )
        if returns_info:
            obs, info = rval
            if self.preprocess_fn:
                processed_data = self.preprocess_fn(
                    obs=obs, info=info, env_id=np.arange(self.env_num)
                )
                obs = processed_data.get("obs", obs)
                info = processed_data.get("info", info)
            self.data.info = info
        else:
            obs = rval
            if self.preprocess_fn:
                obs = self.preprocess_fn(obs=obs,
                                         env_id=np.arange(self.env_num)).get("obs", obs)
        self.data.obs = obs

    def _reset_state(self, id: Union[int, List[int]]) -> None:
        """Reset the hidden state: self.data.state[id]."""
        if hasattr(self.data.policy, "hidden_state"):
            state = self.data.policy.hidden_state  # it is a reference
            if isinstance(state, torch.Tensor):
                state[id].zero_()
            elif isinstance(state, np.ndarray):
                state[id] = None if state.dtype == object else 0
            elif isinstance(state, Batch):
                state.empty_(id)

    def _reset_env_with_ids(
        self,
        local_ids: Union[List[int], np.ndarray],
        global_ids: Union[List[int], np.ndarray],
        gym_reset_kwargs: Optional[Dict[str, Any]] = None,
    ) -> None:
        gym_reset_kwargs = gym_reset_kwargs if gym_reset_kwargs else {}
        rval = self.env.reset(global_ids, **gym_reset_kwargs)
        returns_info = isinstance(rval, (tuple, list)) and len(rval) == 2 and (
            isinstance(rval[1], dict) or isinstance(rval[1][0], dict)
        )
        if returns_info:
            obs_reset, info = rval
            if self.preprocess_fn:
                processed_data = self.preprocess_fn(
                    obs=obs_reset, info=info, env_id=global_ids
                )
                obs_reset = processed_data.get("obs", obs_reset)
                info = processed_data.get("info", info)
            self.data.info[local_ids] = info
        else:
            obs_reset = rval
            if self.preprocess_fn:
                obs_reset = self.preprocess_fn(obs=obs_reset,
                                               env_id=global_ids).get("obs", obs_reset)
        self.data.obs_next[local_ids] = obs_reset

    def collect(
        self,
        n_episode: int = 1,
        random: bool = False,
        render: bool = False,
        no_grad: bool = True,
        gym_reset_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Dict[str, Any]:
        """Collect a specified number of episodes.

        To ensure unbiased sampling result with n_episode option, this function will
        first collect ``n_episode - env_num`` episodes, then for the last ``env_num``
        episodes, they will be collected evenly from each env.

        When the shield is enabled, every proposed action is screened before it is
        sent to the environment. A flagged action is replaced by a safe action for
        the environment step. With ``use_fake_data`` an additional "fake" transition
        (the originally proposed action with ``fake_rew`` / ``fake_cost``) is stored
        in the buffer before the real one; without it, the stored transition keeps
        the proposed action but gets ``fake_rew`` / ``fake_cost`` written into it.

        :param int n_episode: how many episodes you want to collect.
        :param bool random: whether to use random policy for collecting data. Default to
            False.
        :param bool render: Whether to render the environment during evaluation, defaults
            to False
        :param bool no_grad: whether to retain gradient in policy.forward(). Default to
            True (no gradient retaining).
        :param gym_reset_kwargs: extra keyword arguments to pass into the environment's
            reset function. Defaults to None (extra keyword arguments)

        .. note::

            We do not support the `n_step` collection method in Tianshou, because using
            `n_episode` only can facilitate the episodic cost computation and better
            evaluate the agent.

        :raises TypeError: if ``n_episode`` is None.

        :return: A dict including the following keys

            * ``n/ep`` collected number of episodes.
            * ``n/st`` collected number of steps.
            * ``rew`` mean of episodic rewards.
            * ``len`` mean of episodic lengths.
            * ``total_cost`` cumulative costs reported by the environment in this
              collect.
            * ``cost`` mean of episodic costs, where shield-flagged steps count as
              ``fake_cost``.
            * ``truncated`` mean of episodic truncation.
            * ``terminated`` mean of episodic termination.
            * ``llm_call`` ``shield.total_call`` if the shield is enabled, else 0.
        """
        if n_episode is None:
            raise TypeError("Please specify n_episode in FastCollector.collect().")
        assert n_episode > 0
        ready_env_ids = np.arange(min(self.env_num, n_episode))
        self.data = self.data[:min(self.env_num, n_episode)]

        start_time = time.time()

        step_count = 0
        total_cost = 0
        total_cost_fake = 0
        termination_count = 0
        truncation_count = 0
        episode_count = 0
        episode_rews = []
        episode_lens = []

        while True:
            assert len(self.data) == len(ready_env_ids)
            # restore the state: if the last state is None, it won't store
            last_state = self.data.policy.pop("hidden_state", None)

            # get the next action
            if random:
                try:
                    act_sample = [self._action_space[i].sample() for i in ready_env_ids]
                except TypeError:  # envpool's action space is not for per-env
                    act_sample = [self._action_space.sample() for _ in ready_env_ids]
                act_sample = self.policy.map_action_inverse(act_sample)  # type: ignore
                self.data.update(act=act_sample)
                # The shield / fake-data code below reads `act` and `policy`; bind
                # them here as well so that random collection does not hit an
                # unbound (or stale) local.
                act = self.data.act
                policy = self.data.policy
            else:
                if no_grad:
                    with torch.no_grad():  # faster than retain_grad version
                        # self.data.obs will be used by agent to get result
                        result = self.policy(self.data, last_state)
                else:
                    result = self.policy(self.data, last_state)
                # update state / act / policy into self.data
                policy = result.get("policy", Batch())
                assert isinstance(policy, Batch)
                state = result.get("state", None)
                if state is not None:
                    policy.hidden_state = state  # save state into buffer
                act = to_numpy(result.act)
                if self.exploration_noise:
                    act = self.policy.exploration_noise(act, self.data)
                self.data.update(policy=policy, act=act)

            # ---- shield screening -------------------------------------------------
            add_fake_data = False
            fake_ids = []
            # local env index -> bit flags describing which distilled model was used:
            # +1 the cost verdict came from the distilled cost model,
            # +2 the safe action came from the distilled actor.
            distill_ids = {}
            # act_real is what is actually executed; self.data.act keeps the proposal
            act_real = dco(act)
            use_shield = self.use_shield
            use_fake_data = self.use_fake_data
            # the shield is only applied with probability `shield_prop` per step
            if use_shield and rdm.random() > self.shield_prop:
                use_shield = False
                use_fake_data = False

            if use_shield:
                # the distilled models are only trusted after the warm-up period
                distill_ready = (
                    self.use_llm_shield and self.collect_step >= self.prepare_steps
                )
                for i in range(len(self.data.obs)):
                    unc = 1.0
                    distill_out = 0
                    if self.use_llm_shield:
                        with torch.no_grad():
                            cost_shield_distill = self.policy.cost_shield_distill
                            obs_input = torch.tensor(
                                self.data.obs[i].reshape(1, -1),
                                dtype=torch.float32,
                                device=self.policy.device,
                            )
                            action_input = torch.tensor(
                                self.data.act[i].reshape(1,),
                                dtype=torch.long,
                                device=self.policy.device,
                            )
                            action_input = F.one_hot(
                                action_input, num_classes=self.action_dim
                            ).to(torch.float32).to(self.policy.device)
                            cost_shield_out = cost_shield_distill(
                                obs_input, act=action_input
                            ).squeeze()
                            # variance along dim 0 is used as the uncertainty measure
                            # NOTE(review): presumably dim 0 indexes ensemble members
                            # - confirm against the distilled model.
                            unc = cost_shield_out.var(0).mean()
                            cost_pre_shield = (cost_shield_out.mean(0) > 0.5)
                    if distill_ready and unc < self.cost_shield_thres:
                        cost_pre = cost_pre_shield
                        if cost_pre:
                            distill_out += 1
                    else:
                        cost_pre = self.shield.get_cost(
                            self.data.obs[i], self.data.act[i]
                        )
                    if cost_pre:
                        add_fake_data = True
                        fake_ids.append(i)
                        unc_actor = 1.0
                        if self.use_llm_shield:
                            with torch.no_grad():
                                actor_shield_distill = self.policy.actor_shield_distill
                                actor_shield_out = actor_shield_distill(
                                    obs_input
                                ).squeeze()
                                unc_actor = actor_shield_out.var(0).mean()
                                actor_shield_pred = actor_shield_out.mean(0).argmax()
                        if distill_ready and unc_actor < self.actor_shield_thres:
                            safe_action = actor_shield_pred
                            distill_out += 2
                        else:
                            safe_action = self.shield.get_action(self.data.obs[i])
                        act_real[i] = safe_action
                        distill_ids[i] = distill_out

            # get bounded and remapped actions first (not saved into buffer)
            if use_shield:
                action_remap = self.policy.map_action(act_real)
            else:
                action_remap = self.policy.map_action(self.data.act)

            # ---- step in env ------------------------------------------------------
            result = self.env.step(action_remap, ready_env_ids)
            if len(result) == 5:
                obs_next, rew, terminated, truncated, info = result
                done = np.logical_or(terminated, truncated)
            elif len(result) == 4:
                obs_next, rew, done, info = result
                # The per-env "weight" entries are written by the loop below, so no
                # string-keyed write on `info` is done here (it would fail for a
                # sequence of per-env info dicts).
                if isinstance(info, dict):
                    truncated = info["TimeLimit.truncated"]
                else:
                    truncated = np.array(
                        [
                            info_item.get("TimeLimit.truncated", False)
                            for info_item in info
                        ]
                    )
                terminated = np.logical_and(done, ~truncated)
            else:
                raise ValueError()

            self.data.update(info=info)
            # snapshot of the environment's own info, taken before any weight or
            # fake-cost entry is written
            info_org = dco(self.data.info)

            # default sample weights for real transitions
            for idx in range(len(info)):
                info[idx]["weight"] = 1.0
                info[idx]["distill_cost_weight"] = 1.0
                info[idx]["distill_actor_weight"] = 1.0

            if use_fake_data:
                # obs_next_fake stays identical to obs, so that the fake sample does
                # not actually change the observation
                if self.fake_next_obs:
                    obs_next_fake = dco(self.data.obs)
                else:
                    obs_next_fake = obs_next
                rew_fake = dco(rew)
                terminated_fake = dco(terminated)
                truncated_fake = dco(truncated)
                done_fake = dco(done)
                info_fake = dco(info)
            if add_fake_data:
                for idx in fake_ids:
                    if use_fake_data:
                        rew_fake[idx] = self.fake_rew
                        terminated_fake[idx] = 0.0
                        truncated_fake[idx] = 0.0
                        done_fake[idx] = 0.0
                        if self.fake_done:
                            terminated_fake[idx] = 1.0
                            truncated_fake[idx] = 1.0
                            done_fake[idx] = 1.0
                        info_fake[idx]['cost'] = self.fake_cost
                        info_fake[idx]["weight"] = self.fake_weight
                        # a weight of 1.0 marks verdicts/actions that came from the
                        # shield itself rather than from a distilled model
                        info_fake[idx]["distill_cost_weight"] = float(
                            (distill_ids[idx] == 0) or (distill_ids[idx] == 2)
                        )
                        info_fake[idx]["distill_actor_weight"] = float(
                            (distill_ids[idx] == 0) or (distill_ids[idx] == 1)
                        )
                    else:
                        # without fake data the penalty is written straight into the
                        # real transition
                        rew[idx] = self.fake_rew
                        info[idx]['cost'] = self.fake_cost
            if use_fake_data:
                # fake data: proposed action, policy, obs_next_fake, rew_fake,
                # done_fake
                self.data.update(
                    obs_next=obs_next_fake,
                    rew=rew_fake,
                    terminated=terminated_fake,
                    truncated=truncated_fake,
                    done=done_fake,
                    info=info_fake
                )
                if self.preprocess_fn:
                    # step-level preprocessing is not supported on this path
                    assert False
                cost = info_org.get("cost", np.zeros(rew.shape))
                cost_fake = dco(cost)
                total_cost += np.sum(cost)
                if add_fake_data:
                    for idx in fake_ids:
                        cost_fake[idx] = self.fake_cost
                total_cost_fake += np.sum(cost_fake)
                self.data.update(cost=cost_fake)
                if len(fake_ids) > 0:
                    # only the flagged envs contribute a fake transition; the return
                    # value is superseded by the real add below
                    self.buffer.add(
                        self.data[fake_ids], buffer_ids=ready_env_ids[fake_ids]
                    )
                # real data: executed action and the environment's actual outcome
                self.data.update(policy=policy, act=act_real)
                self.data.update(
                    obs_next=obs_next,
                    rew=rew,
                    terminated=terminated,
                    truncated=truncated,
                    done=done,
                    info=info
                )
                if self.preprocess_fn:
                    assert False
                self.data.update(cost=cost)
            else:
                self.data.update(
                    obs_next=obs_next,
                    rew=rew,
                    terminated=terminated,
                    truncated=truncated,
                    done=done,
                    info=info
                )
                if self.preprocess_fn:
                    assert False
                    self.data.update(
                        self.preprocess_fn(
                            obs_next=self.data.obs_next,
                            rew=self.data.rew,
                            done=self.data.done,
                            info=self.data.info,
                            policy=self.data.policy,
                            env_id=ready_env_ids,
                        )
                    )

                cost = info_org.get("cost", np.zeros(rew.shape))
                total_cost += np.sum(cost)
                if add_fake_data:
                    for idx in fake_ids:
                        cost[idx] = self.fake_cost
                total_cost_fake += np.sum(cost)
                self.data.update(cost=cost)

            if render:
                self.env.render()

            # add data into the buffer
            _, ep_rew, ep_len, _ = self.buffer.add(
                self.data, buffer_ids=ready_env_ids
            )

            # collect statistics
            step_count += len(ready_env_ids)

            if np.any(done):
                env_ind_local = np.where(done)[0]
                env_ind_global = ready_env_ids[env_ind_local]
                episode_count += len(env_ind_local)
                episode_lens.append(ep_len[env_ind_local])
                episode_rews.append(ep_rew[env_ind_local])
                termination_count += np.sum(terminated)
                truncation_count += np.sum(truncated)
                # now we copy obs_next to obs, but since there might be finished
                # episodes, we have to reset finished envs first.
                self._reset_env_with_ids(env_ind_local, env_ind_global, gym_reset_kwargs)
                for i in env_ind_local:
                    self._reset_state(i)

                # remove surplus env id from ready_env_ids to avoid bias in selecting
                # environments
                if n_episode:
                    surplus_env_num = len(ready_env_ids) - (n_episode - episode_count)
                    if surplus_env_num > 0:
                        mask = np.ones_like(ready_env_ids, dtype=bool)
                        mask[env_ind_local[:surplus_env_num]] = False
                        ready_env_ids = ready_env_ids[mask]
                        self.data = self.data[mask]

            self.data.obs = self.data.obs_next

            if n_episode and episode_count >= n_episode:
                break

        # generate statistics
        self.collect_step += step_count
        self.collect_episode += episode_count
        self.collect_time += max(time.time() - start_time, 1e-9)

        if n_episode:
            # start the next collect from freshly reset environments
            self.data = Batch(
                obs={},
                act={},
                rew={},
                cost={},
                terminated={},
                truncated={},
                done={},
                obs_next={},
                info={},
                policy={}
            )
            self.reset_env()

        if episode_count > 0:
            rews, lens = list(map(np.concatenate, [episode_rews, episode_lens]))
            rew_mean = rews.mean()
            len_mean = lens.mean()
        else:
            rew_mean = len_mean = 0

        done_count = termination_count + truncation_count

        llm_call = 0
        if self.use_shield:
            llm_call = self.shield.total_call

        return {
            "n/ep": episode_count,
            "n/st": step_count,
            "rew": rew_mean,
            "len": len_mean,
            "total_cost": total_cost,
            "cost": total_cost_fake / episode_count,
            "truncated": truncation_count / done_count,
            "terminated": termination_count / done_count,
            "llm_call": llm_call,
        }
